{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "LeNet(\n  (conv): Sequential(\n    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n    (1): Sigmoid()\n    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n    (4): Sigmoid()\n    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  )\n  (fc): Sequential(\n    (0): Linear(in_features=256, out_features=120, bias=True)\n    (1): Sigmoid()\n    (2): Linear(in_features=120, out_features=84, bias=True)\n    (3): Sigmoid()\n    (4): Linear(in_features=84, out_features=10, bias=True)\n  )\n)\n"
    }
   ],
   "source": [
    "import time\n",
    "import torch\n",
    "import torch.nn as nn \n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "class LeNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LeNet,self).__init__()\n",
    "\n",
    "        self.conv = nn.Sequential(\n",
    "            nn.Conv2d(1,6,5),\n",
    "            nn.Sigmoid(),\n",
    "            nn.MaxPool2d(2,2), # kernel_size ,stride = 2\n",
    "            nn.Conv2d(6,16,5),\n",
    "            nn.Sigmoid(),\n",
    "            nn.MaxPool2d(2,2)\n",
    "        )\n",
    "\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(16 * 4 * 4 ,120),\n",
    "            nn.Sigmoid(),\n",
    "            nn.Linear(120,84),\n",
    "            nn.Sigmoid(),\n",
    "            nn.Linear(84,10)\n",
    "        )\n",
    "\n",
    "    def forward(self,img):\n",
    "        feature = self.conv(img)\n",
    "\n",
    "        output = self.fc(feature.view(img.shape[0],-1)) # Flatter\n",
    "        return output\n",
    "\n",
    "# Build the model once here (later cells train this instance) and print its layer structure.\n",
    "net = LeNet()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import d2lzh as d2l \n",
    "import torch\n",
    "\n",
    "batch_size = 256  # mini-batch size handed to the data loader\n",
    "# The loader's result is unpacked as (training iterator, test iterator).\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)\n",
    "\n",
    "def train_ch5(net,train_iter,test_iter,batch_size,optimizer,device,num_epochs):\n",
    "    net = net.to(device)\n",
    "    print(\"training on \",device)\n",
    "\n",
    "    loss = torch.nn.CrossEntropyLoss()\n",
    "    batch_count = 0\n",
    "    for epoch in range(num_epochs):\n",
    "        train_l_sum,train_acc_sum,n,start = 0.0,0.0,0,time.time()\n",
    "\n",
    "        for x,y in train_iter:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "            y_hat = net(x)\n",
    "            l = loss(y_hat,y)\n",
    "            optimizer.zero_grad()\n",
    "            l.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            train_l_sum += l.cpu().item()\n",
    "            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()\n",
    "\n",
    "            n += y.shape[0]\n",
    "            batch_count += 1\n",
    "        test_acc = d2l.evaluate_accuracy(test_iter,net)\n",
    "\n",
    "        print('epoch %d, loss %.4f, train acc %.3f,test acc %.3f,time %.1f' %(epoch + 1,train_l_sum / batch_count,train_acc_sum / n ,test_acc,time.time()-start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "training on  cuda\nepoch 1, loss 1.7925, train acc 0.339,test acc 0.587,time 3.8\nepoch 2, loss 0.4724, train acc 0.639,test acc 0.700,time 1.9\nepoch 3, loss 0.2546, train acc 0.719,test acc 0.731,time 1.9\nepoch 4, loss 0.1700, train acc 0.742,test acc 0.742,time 2.0\nepoch 5, loss 0.1242, train acc 0.758,test acc 0.757,time 1.9\n"
    }
   ],
   "source": [
    "# Hyperparameters for this run.\n",
    "lr = 0.001\n",
    "num_epochs = 5\n",
    "\n",
    "# Adam optimises every parameter of the network with the learning rate above.\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "\n",
    "train_ch5(\n",
    "    net=net,\n",
    "    train_iter=train_iter,\n",
    "    test_iter=test_iter,\n",
    "    batch_size=batch_size,\n",
    "    optimizer=optimizer,\n",
    "    device=device,\n",
    "    num_epochs=num_epochs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python37764bitpytorchnotebookconda6e7a8693df0d4d92aca09d521275d23a",
   "display_name": "Python 3.7.7 64-bit ('pytorch_notebook': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}